/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
 *
 * @brief unit test of the matmul CopyCubeInParams (B matrix) module for ascend910B1
 *
 */
#include <gtest/gtest.h>
#include "kernel_operator.h"
#include "lib/matmul/tiling.h"
#include "impl/matmul/modules/matmul_param.h"
#include "impl/matmul/modules/matmul_policy.h"
#define private public
#include "impl/matmul/modules/stage/copy_cube_in/copy_cube_in_params.h"
#include "impl/matmul/modules/matmul_private_modules.h"
#include "base_tiling_struct.h"

using namespace std;
using namespace AscendC;

namespace {
/**
 * @brief Minimal MatmulImpl stand-in used to drive the CopyCubeInParamsB module in isolation.
 *
 * Only the modules needed by this test are mixed in (NLoop, KLoop, CopyCubeInParamsB,
 * MatmulShapeInfo, MatmulShapeTiling). The test seeds the shared variable-parameter
 * block `var` through InitVar()/SetRuntimeParams() instead of running a real matmul.
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const MatmulConfig& MM_CFG, class MM_CB,
MATMUL_POLICY_DEFAULT_OF(MatmulPolicy)>
class MatmulImpl
: MATMUL_IMPORT_MODULE(NLoop)
, MATMUL_IMPORT_MODULE(KLoop)
, MATMUL_IMPORT_MODULE_PRIVATE(CopyCubeInParamsB)
, MATMUL_IMPORT_MODULE_PRIVATE(MatmulShapeInfo)
, MATMUL_IMPORT_MODULE_PRIVATE(MatmulShapeTiling)
{
    MATMUL_ALLOW_USING_PRIVATE(CopyCubeInParamsB);
    MATMUL_ALLOW_USING_PRIVATE(MatmulShapeInfo);
    MATMUL_ALLOW_USING_PRIVATE(MatmulShapeTiling);
    MATMUL_ALLOW_USING(NLoop);
    MATMUL_ALLOW_USING(KLoop);

public:
    using IMPL = MatmulImpl<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY>;
    using VAR_PARAMS =
        typename Impl::Detail::MatmulParams<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, GetMatmulVersion(MM_CFG)>::PARAMS;

    MATMUL_USE_MODULE(NLoop);
    MATMUL_USE_MODULE(KLoop);
    MatmulImpl() {}

    /// @brief Mutable access to the variable-parameter block shared with the mixed-in modules.
    VAR_PARAMS& GetVar() {
        return var;
    }

    /**
     * @brief Binds the tiling and this object's TPipe into `var`.
     * @param tiling Tiling passed to var.tiling_ by address; presumably stored, so it must
     *               outlive this object (NOTE(review): confirm SetTiling keeps the pointer).
     */
    void InitVar(const TCubeTiling &tiling) {
        var.tiling_.SetTiling(&tiling);
        var.tpipe_ = &pipe;
    }

    /**
     * @brief Derives single-core shapes from the bound tiling and initializes the K/N loops.
     *
     * Must be called after InitVar(). Only @p isTranspose is consumed; @p baseUseK,
     * @p baseUseN, @p stepKbIdx and @p stepNIdx are accepted for call-site compatibility
     * but are currently ignored.
     *
     * @param baseUseK    Unused.
     * @param baseUseN    Unused.
     * @param stepKbIdx   Unused.
     * @param stepNIdx    Unused.
     * @param isTranspose Stored as var.isTransposeB_.
     */
    void SetRuntimeParams(int32_t baseUseK, int32_t baseUseN, int32_t stepKbIdx = 0, int32_t stepNIdx = 0, bool isTranspose = false) {
        (void)baseUseK;
        (void)baseUseN;
        (void)stepKbIdx;
        (void)stepNIdx;

        var.singleCoreM_ = var.tiling_.GetSingleCoreM();
        var.singleCoreN_ = var.tiling_.GetSingleCoreN();
        var.singleCoreK_ = var.tiling_.GetSingleCoreK();
        // NOTE(review): N_ is taken from GetSingleCoreM(), not GetSingleCoreN(). This looks like a
        // copy-paste slip, but it is left as-is because test expectations (e.g. GetOrgWidth())
        // may rely on the current value - confirm before changing.
        N_ = var.tiling_.GetSingleCoreM();
        Kb_ = var.tiling_.GetSingleCoreK();
        var.isTransposeB_ = isTranspose;
        MATMUL_MODULE(KLoop)->Init(var.singleCoreK_);
        MATMUL_MODULE(NLoop)->Init(var.singleCoreN_);
        // B1 counts as K-full-load when one stepKb spans every K iteration of the loop.
        var.isB1KFullLoad_ = (var.tiling_.GetStepKb() >= MATMUL_MODULE(KLoop)->GetTotalIter());
    }

private:
    TPipe pipe;
    VAR_PARAMS var;
    // Value-initialized so reads through private-module access are never indeterminate;
    // M_ and Ka_ are never assigned by this test stand-in.
    int32_t M_ = 0;
    int32_t N_ = 0;
    int32_t Ka_ = 0;
    int32_t Kb_ = 0;
};
}

class TestCopyCubeInParams : public testing::Test {
protected:
    void SetUp() {}
    void TearDown() {}

private:
    using A_TYPE = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, int8_t, false>;
    using B_TYPE = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, int8_t, false>;
    using C_TYPE = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float>;
    using BIAS_TYPE = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float>;

    MatmulImpl<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, CFG_MDL, void> mm;
};

TEST_F(TestCopyCubeInParams, all_interface) {
    // Build the TCubeTiling from the compact TilingParams description.
    TilingParams tilingParams = {1, 64, 48, 256, 64, 48, 256, 32, 48, 96, 2, 4, 1, 2, 1, 3, 1, 1};
    TCubeTiling tiling;
    tilingParams.GetTiling(tiling);

    // Runtime arguments forwarded to SetRuntimeParams.
    const int32_t useK = 96;
    const int32_t useN = 48;
    const int32_t kbStepIndex = 0;
    const int32_t nStepIndex = 1;

    mm.InitVar(tiling);
    mm.SetRuntimeParams(useK, useN, kbStepIndex, nStepIndex);

    // Step / buffer-position queries.
    EXPECT_EQ(mm.GetStepCol(), 2);
    EXPECT_EQ(mm.GetStepRow(), 3);
    EXPECT_EQ(mm.GetBufferPos(), kbStepIndex);
    EXPECT_EQ(mm.IsKRowDirec(), true);

    // Shape queries: original, base and single-core extents.
    EXPECT_EQ(mm.GetOrgHeight(), 256);
    EXPECT_EQ(mm.GetOrgWidth(), 64);
    EXPECT_EQ(mm.GetBaseHeight(), 96);
    EXPECT_EQ(mm.GetBaseWidth(), 48);
    EXPECT_EQ(mm.GetSingleHeight(), 256);
    EXPECT_EQ(mm.GetSingleWidth(), 48);

    // GetTotalRow()/GetTotalCol() checks are currently disabled.

    // Buffer sizing queries.
    EXPECT_EQ(mm.GetBufferSize(), 96 * 64);
    EXPECT_EQ(mm.GetDepth(), 4);
}
